import math

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
def import_class(name):
    """Resolve a dotted path such as ``'package.module.ClassName'`` to an object.

    The first component is imported as a top-level module and every following
    component is looked up as an attribute.  When an attribute is missing, the
    dotted prefix walked so far is imported as a submodule instead, so paths
    through submodules that have not been imported yet also resolve.

    Args:
        name: Dotted path of the object to load.

    Returns:
        The module, class or other attribute the path refers to.

    Raises:
        ImportError: If the top-level module cannot be imported.
        AttributeError: If a later component is neither an attribute of its
            parent nor an importable submodule.
    """
    import importlib

    components = name.split('.')
    path = components[0]
    mod = importlib.import_module(path)
    for comp in components[1:]:
        path = path + '.' + comp
        try:
            mod = getattr(mod, comp)
        except AttributeError as attr_err:
            # A bare import of the top-level package does not load its
            # submodules, so try importing the dotted prefix before giving up.
            try:
                mod = importlib.import_module(path)
            except ImportError:
                # Keep the exception type that plain attribute lookup raises.
                raise attr_err
    return mod


def conv_branch_init(conv, branches):
    """Initialise a convolution that is one of several parallel branches.

    The weight is filled from a zero-mean normal distribution whose standard
    deviation is ``sqrt(2 / (d0 * d1 * d2 * branches))``, where ``d0..d2`` are
    the first three dimensions of the weight tensor.  A bias, when present,
    is set to zero.

    Args:
        conv: Layer exposing ``weight`` and ``bias`` attributes.
        branches: Number of parallel branches used to scale the deviation.
    """
    kernel = conv.weight
    dim0, dim1, dim2 = (kernel.size(axis) for axis in range(3))
    std = math.sqrt(2. / (dim0 * dim1 * dim2 * branches))
    nn.init.normal_(kernel, 0, std)
    if conv.bias is None:
        return
    nn.init.constant_(conv.bias, 0)


def conv_init(conv):
    """Apply Kaiming-normal (``fan_out``) weights and a zero bias to ``conv``.

    Either tensor is skipped when it is ``None``.

    Args:
        conv: Layer exposing ``weight`` and ``bias`` attributes.
    """
    kernel = conv.weight
    if kernel is not None:
        nn.init.kaiming_normal_(kernel, mode='fan_out')
    offset = conv.bias
    if offset is not None:
        nn.init.constant_(offset, 0)


def bn_init(bn, scale):
    """Fill a normalisation layer's weight with ``scale`` and its bias with 0.

    Args:
        bn: Layer exposing ``weight`` and ``bias`` tensors.
        scale: Constant written into every element of ``bn.weight``.
    """
    for tensor, value in ((bn.weight, scale), (bn.bias, 0)):
        nn.init.constant_(tensor, value)


def weights_init(m):
    """Initialise ``m`` according to its class name; meant for ``Module.apply``.

    * Class name contains ``'Conv'``: Kaiming-normal (``fan_out``) weight and,
      when the bias is a tensor, a zero bias.
    * Otherwise, class name contains ``'BatchNorm'``: weight drawn from
      N(1.0, 0.02) and bias filled with zeros, each only when not ``None``.
    * Anything else is left untouched.

    The attribute checks matter because the match is done on the class name
    only, so container modules whose name happens to contain ``'Conv'`` but
    that own no ``weight``/``bias`` are also visited.
    """
    kind = type(m).__name__

    if 'Conv' in kind:
        if hasattr(m, 'weight'):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        conv_bias = getattr(m, 'bias', None)
        if isinstance(conv_bias, torch.Tensor):
            nn.init.constant_(conv_bias, 0)
        return

    if 'BatchNorm' in kind:
        gamma = getattr(m, 'weight', None)
        if gamma is not None:
            gamma.data.normal_(1.0, 0.02)
        beta = getattr(m, 'bias', None)
        if beta is not None:
            beta.data.fill_(0)



class TemporalConv(nn.Module):
    """A ``(kernel_size, 1)`` convolution followed by batch normalisation.

    The kernel, stride, dilation and padding act only on the first spatial
    axis of the input; the second spatial axis uses size 1, stride 1 and no
    padding.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel extent along the first spatial axis.
        stride: Stride along the first spatial axis.
        dilation: Dilation along the first spatial axis.
        groups: Number of convolution groups (defaults to 8).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=8):
        super().__init__()
        # Extent covered by the dilated kernel; padding is half of what the
        # kernel reaches beyond its centre tap.
        dilated_extent = kernel_size + (kernel_size - 1) * (dilation - 1)
        temporal_pad = (dilated_extent - 1) // 2

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(temporal_pad, 0),
            dilation=(dilation, 1),
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``bn(conv(x))``."""
        return self.bn(self.conv(x))

class MultiScale_TemporalConv(nn.Module):
    """Multi-branch temporal convolution whose branch outputs are concatenated.

    The module builds ``len(dilations) + 2`` branches, each producing
    ``out_channels // num_branches`` channels:

    * one branch per dilation: 1x1 conv -> BN -> ReLU -> ``TemporalConv``;
    * a max-pool branch: 1x1 conv -> BN -> ReLU -> (3, 1) max-pool -> BN;
    * a plain 1x1 branch: strided 1x1 conv -> BN.

    Branch outputs are concatenated along the channel axis and the residual
    is added to the result.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; must be a multiple of the
            number of branches.
        kernel_size: One temporal kernel size shared by all dilated branches,
            or a list/tuple with one kernel size per dilation.
        stride: Stride applied along the first spatial axis in every branch.
        dilations: Dilation of each dilated branch (the default list is
            only read, never mutated).
        residual: When falsy, no residual is added.
        residual_kernel_size: Kernel size of the projection residual.

    Raises:
        AssertionError: If ``out_channels`` is not a multiple of the number of
            branches, or a ``kernel_size`` sequence does not match
            ``dilations`` in length.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilations=[1,2,3,4],
                 residual=False,
                 residual_kernel_size=1):

        super().__init__()
        assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches'

        # Multiple branches of temporal convolution
        self.num_branches = len(dilations) + 2
        branch_channels = out_channels // self.num_branches
        # Accept tuples as well as lists: a tuple used to be replicated as a
        # single "kernel size" and then broke the padding arithmetic.
        if isinstance(kernel_size, (list, tuple)):
            assert len(kernel_size) == len(dilations)
        else:
            kernel_size = [kernel_size] * len(dilations)

        # Temporal Convolution branches
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    branch_channels,
                    kernel_size=1,
                    padding=0),
                nn.BatchNorm2d(branch_channels),
                nn.ReLU(inplace=True),
                TemporalConv(
                    branch_channels,
                    branch_channels,
                    kernel_size=ks,
                    stride=stride,
                    dilation=dilation),
            )
            for ks, dilation in zip(kernel_size, dilations)
        ])

        # Additional Max & 1x1 branch
        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3,1), stride=(stride,1), padding=(1,0)),
            nn.BatchNorm2d(branch_channels)  # open question from the author: why is another BN needed here?
        ))

        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride,1)),
            nn.BatchNorm2d(branch_channels)
        ))

        # Residual connection
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride)

        # initialize every registered submodule
        self.apply(weights_init)

    def forward(self, x):
        """Run every branch on ``x``, concatenate on channels, add the residual."""
        # Input dim: (N,C,T,V)
        res = self.residual(x)
        out = torch.cat([branch(x) for branch in self.branches], dim=1)
        out += res
        return out

# class unit_tcn(nn.Module):
#     def __init__(self, in_channels, out_channels, kernel_size=5, stride=1):
#         super(unit_tcn, self).__init__()
#         pad = int((kernel_size - 1) / 2)
#         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0),
#                               stride=(stride, 1))
#
#         self.bn = nn.BatchNorm2d(out_channels)
#         self.relu = nn.ReLU(inplace=True)
#         conv_init(self.conv)
#         bn_init(self.bn, 1)
#
#     def forward(self, x):
#         x = self.bn(self.conv(x))
#         return x


class fc_chain(nn.Module):
    """Learnable joint-mixing layer followed by a (grouped) 1x1 convolution.

    ``fc1`` holds one ``num_joints x num_joints`` matrix per head, initialised
    to the identity.  In ``forward`` each matrix is divided by the L2 norm
    taken over its dim 1 and multiplied onto the last axis of the input; the
    result goes through ``fc2`` (1x1 conv with ``groups=num_heads``) and batch
    norm, the shortcut ``down(x)`` is added, and ReLU is applied.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        num_joints: Size of the last input axis (side of each ``fc1`` matrix).
        num_heads: Number of mixing matrices / convolution groups.
    """

    def __init__(self, in_channels, out_channels, num_joints=25, num_heads=1):
        super(fc_chain, self).__init__()
        self.out_c = out_channels
        self.in_c = in_channels
        self.num_heads = num_heads
        # Shape (num_heads, num_joints, num_joints), identity per head.
        self.fc1 = nn.Parameter(torch.stack([torch.eye(num_joints)] * num_heads, dim=0), requires_grad=True)
        self.fc2 = nn.Conv2d(in_channels, out_channels, 1, groups=num_heads)

        if in_channels != out_channels:
            # Project the shortcut when the channel count changes.
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels)
            )
        else:
            # nn.Identity instead of a lambda: it is a registered, picklable
            # module and adds no parameters to the state dict.
            self.down = nn.Identity()

        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        # Override the main-path BN scale set in the loop above with a
        # near-zero value.
        bn_init(self.bn, 1e-6)

    def L2_norm(self, weight):
        """Return the L2 norm of ``weight`` over dim 1 (kept), plus 1e-4."""
        return torch.norm(weight, 2, dim=1, keepdim=True) + 1e-4  # H, 1, V

    def forward(self, x):
        """Apply joint mixing, 1x1 conv, BN, shortcut and ReLU to ``x``.

        ``x`` is unpacked as ``(N, C, T, V)``.
        """
        N, C, T, V = x.size()
        w1 = self.fc1 / self.L2_norm(self.fc1)
        if self.num_heads != 1:
            # Split channels into heads; reshape also accepts non-contiguous x.
            x1 = x.reshape(N, self.num_heads, C // self.num_heads, T, V)
            z = torch.einsum("nhctv, hvw->nhctw", x1, w1).contiguous().view(N, -1, T, V)
        else:
            z = torch.einsum("nctv, hvw->nctw", x, w1)
        z = self.fc2(z)
        y = self.bn(z)
        y += self.down(x)
        y = self.relu(y)

        return y


class FC_TCN_Block(nn.Module):
    """One network stage: ``fc_chain`` -> multi-scale temporal conv -> ReLU.

    The block computes ``relu(tcn1(fc_chain(x)) + residual(x))``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        num_joints: Passed to ``fc_chain`` as ``num_joints``.
        stride: Stride passed to the temporal convolution and, when a
            projection is needed, to the residual path.
        residual: When falsy, the residual term is the constant 0.
        num_heads: Passed to ``fc_chain`` as ``num_heads``.
        kernel_size: Kernel size passed to ``MultiScale_TemporalConv``.
        dilations: Dilations passed to ``MultiScale_TemporalConv`` (the
            default list is only read, never mutated).
    """

    def __init__(self, in_channels, out_channels, num_joints, stride=1, residual=True, num_heads=1, kernel_size=5, dilations=[1,2]):
        super(FC_TCN_Block, self).__init__()
        self.fc_chain = fc_chain(in_channels, out_channels, num_joints=num_joints, num_heads=num_heads)
        self.tcn1 = MultiScale_TemporalConv(out_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                            dilations=dilations,
                                            residual=False)
        self.relu = nn.ReLU(inplace=True)
        if not residual:
            self.residual = lambda x: 0

        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x

        else:
            # The former `unit_tcn` helper is no longer defined in this file,
            # so referencing it raised NameError.  An ungrouped TemporalConv
            # with kernel size 1 builds the same strided 1x1 conv + BN under
            # the same `conv` / `bn` attribute names, and it is initialised
            # the way `unit_tcn` was.
            self.residual = TemporalConv(in_channels, out_channels, kernel_size=1, stride=stride, groups=1)
            conv_init(self.residual.conv)
            bn_init(self.residual.bn, 1)

    def forward(self, x):
        """Return ``relu(tcn1(fc_chain(x)) + residual(x))``."""
        y = self.relu(self.tcn1(self.fc_chain(x)) + self.residual(x))
        return y


class Model(nn.Module):
    """Twelve stacked ``FC_TCN_Block`` stages followed by a linear classifier.

    Channel widths are 128 (l1-l4), 256 (l5-l8) and 512 (l9-l12); l5 and l9
    use stride 2.  The first block is built without a residual and with a
    single head.

    Args:
        num_class: Output size of the final linear layer.
        num_point: Number of joints; passed to every block.
        num_person: Number of persons, used to size the input batch norm.
        in_channels: Channels per joint in the input.
        drop_out: Dropout probability before the classifier; falsy disables
            dropout.
        num_heads: Number of heads for every block except the first.
    """

    def __init__(self, num_class=60, num_point=25, num_person=2, in_channels=3,
                 drop_out=0, num_heads=8):
        super(Model, self).__init__()

        self.num_class = num_class
        self.num_points = num_point
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)

        # Use `in_channels` rather than a hard-coded 3 so the first block
        # agrees with `data_bn` for inputs with a different channel count.
        self.l1 = FC_TCN_Block(in_channels, 128, num_point, residual=False, num_heads=1)
        self.l2 = FC_TCN_Block(128, 128, num_point, num_heads=num_heads)
        self.l3 = FC_TCN_Block(128, 128, num_point, num_heads=num_heads)
        self.l4 = FC_TCN_Block(128, 128, num_point, num_heads=num_heads)
        self.l5 = FC_TCN_Block(128, 256, num_point, stride=2, num_heads=num_heads)
        self.l6 = FC_TCN_Block(256, 256, num_point, num_heads=num_heads)
        self.l7 = FC_TCN_Block(256, 256, num_point, num_heads=num_heads)
        self.l8 = FC_TCN_Block(256, 256, num_point, num_heads=num_heads)
        self.l9 = FC_TCN_Block(256, 512, num_point, stride=2, num_heads=num_heads)
        self.l10 = FC_TCN_Block(512, 512, num_point, num_heads=num_heads)
        self.l11 = FC_TCN_Block(512, 512, num_point, num_heads=num_heads)
        self.l12 = FC_TCN_Block(512, 512, num_point, num_heads=num_heads)
        self.fc = nn.Linear(512, num_class)
        nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
        bn_init(self.data_bn, 1)
        if drop_out:
            self.drop_out = nn.Dropout(drop_out)
        else:
            self.drop_out = lambda x: x

    def forward(self, x, y):
        """Classify ``x`` and pass ``y`` through unchanged.

        Args:
            x: Tensor unpacked as ``(N, C, T, V, M)``.
            y: Returned as-is alongside the logits.

        Returns:
            Tuple ``(logits, y)`` where ``logits`` is the output of ``self.fc``.
        """
        N, C, T, V, M = x.size()
        # Flatten (M, V, C) into the channel axis for the 1-D batch norm.
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        x = self.data_bn(x)
        # Fold persons into the batch axis: (N * M, C, T, V).
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)

        stages = (self.l1, self.l2, self.l3, self.l4,
                  self.l5, self.l6, self.l7, self.l8,
                  self.l9, self.l10, self.l11, self.l12)
        for stage in stages:
            x = stage(x)

        # N*M,C,T,V
        c_new = x.size(1)
        x = x.view(N, M, c_new, -1)
        # Average over the flattened (T, V) axis, then over persons.
        x = x.mean(3).mean(1)
        x = self.drop_out(x)

        return self.fc(x), y
